import sys
import os
# Make the sibling ``sae-jax`` directory importable (presumably where the
# ``sae_save_load`` helpers imported further down live). The path is resolved
# relative to this file so the script works from any working directory.
helper_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, 'sae-jax')
)
if helper_path not in sys.path:
    sys.path.append(helper_path)

import jax
import jax.numpy as jnp
import json
from functools import partial
import numpy as np
from transformers import AutoTokenizer
from sae_save_load import (
    save_model, 
    load_model, 
    save_checkpoint, 
    load_checkpoint,
    save_metadata
)

# Short Gemma model identifier; can be overridden through the GEMMA_MODEL_NAME
# environment variable. The "_FULL" variant prefixes it with the "google/"
# namespace and is what gets passed to AutoTokenizer.from_pretrained below.
GEMMA_MODEL_NAME = os.getenv('GEMMA_MODEL_NAME', "gemma-2-2b")
GEMMA_MODEL_NAME_FULL = f"google/{GEMMA_MODEL_NAME}"
print(f"GEMMA_MODEL_NAME: {GEMMA_MODEL_NAME}")
print(f"GEMMA_MODEL_NAME_FULL: {GEMMA_MODEL_NAME_FULL}")

@jax.jit
def process_batch(params, batch, k):
    """
    Encode a batch into top-k sparse codes and decode them again.

    The tied bias is subtracted from the batch, the result is projected by
    the encoder, and every latent whose magnitude is below the k-th largest
    magnitude along the last axis is zeroed. Latents that tie with the
    threshold are all kept, so more than k entries can survive. The sparse
    codes are then decoded and the tied bias is added back.

    Args:
        params: Mapping with a "tied_bias" entry plus "encoder" and "decoder"
            entries, each holding "weights" and "bias".
        batch: Input batch
        k: Number of active units

    Returns:
        Tuple of (sparse_codes, reconstructions)
    """
    tied_bias = params["tied_bias"]
    encoder = params["encoder"]
    decoder = params["decoder"]

    # Encoder: centre the inputs with the tied bias, then project.
    pre_activations = jnp.dot(batch - tied_bias, encoder["weights"]) + encoder["bias"]

    # Threshold = k-th largest magnitude along the last axis. Sorting the
    # negated magnitudes and negating back yields a descending order.
    magnitudes = jnp.abs(pre_activations)
    descending = -jnp.sort(-magnitudes, axis=-1)
    threshold = descending[..., k - 1][..., None]

    # Keep signed activations at or above the threshold, zero the rest.
    sparse_codes = jnp.where(magnitudes >= threshold, pre_activations, 0)

    # Decoder: project back and undo the centring.
    reconstructions = jnp.dot(sparse_codes, decoder["weights"]) + decoder["bias"] + tied_bias

    return sparse_codes, reconstructions

def get_sparse_representations_and_reconstructions(model_params, inputs, k, batch_size=1024):
    """
    Get sparse representations and reconstructions for the entire dataset.

    The inputs are sliced along their first axis into chunks of at most
    ``batch_size`` rows, each chunk is passed through ``process_batch``, and
    the per-chunk results are copied to NumPy and concatenated. A progress
    line is printed after every tenth chunk.

    Args:
        model_params: Model parameters
        inputs: Input data; must expose ``.shape`` and support slicing along
            the first axis.
        k: Number of active units
        batch_size: Batch size for processing; must be positive.

    Returns:
        Tuple of (sparse_codes, reconstructions) as NumPy arrays. For an
        input with zero rows both arrays have zero rows as well.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        # Fail early with a clear message instead of an opaque error from
        # range() (step 0) or np.concatenate (negative step -> no batches).
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    num_samples = inputs.shape[0]

    # Create a partially applied function with fixed parameters
    batch_processor = partial(process_batch, model_params, k=k)

    if num_samples == 0:
        # np.concatenate rejects an empty list, so push the empty input
        # through once; the outputs then still carry the right trailing dims.
        empty_codes, empty_reconstructions = batch_processor(inputs)
        return np.array(empty_codes), np.array(empty_reconstructions)

    codes = []
    reconstructions = []
    batch_starts = range(0, num_samples, batch_size)
    for batch_number, start in enumerate(batch_starts, start=1):
        stop = min(start + batch_size, num_samples)
        batch_code, batch_reconstruction = batch_processor(inputs[start:stop])
        # Move results to host memory so device buffers do not accumulate;
        # np.concatenate below produces the final owned arrays.
        codes.append(np.asarray(batch_code))
        reconstructions.append(np.asarray(batch_reconstruction))
        if batch_number % 10 == 0:
            print(f"Processed {stop}/{num_samples} samples")

    return np.concatenate(codes), np.concatenate(reconstructions)

if __name__ == "__main__":
    # "~" is expanded by shells, not by open()/np.load/np.save, so the home
    # directory has to be resolved explicitly; otherwise the files would be
    # looked up under a literal "~" directory relative to the cwd.
    model_path = os.path.expanduser(f'~/{GEMMA_MODEL_NAME}-sae/k5_final_sae_model.pkl')
    unembeddings_path = os.path.expanduser(
        f'~/unembeddings/{GEMMA_MODEL_NAME}/clean_unembeddings.npy'
    )
    codes_path = os.path.expanduser(f'~/{GEMMA_MODEL_NAME}-sae/k5_whole_sae_final_z.npy')

    # Load model
    model = load_model(model_path)
    print("Model loaded!")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(GEMMA_MODEL_NAME_FULL)
    # Get the vocabulary as a dict: token -> ID
    vocab_dict = tokenizer.get_vocab()
    g = jnp.load(unembeddings_path)

    # Invert the vocabulary into an ID-indexed list; IDs that have no token
    # keep the "<unused>" placeholder.
    vocab_list = ["<unused>"] * (max(vocab_dict.values()) + 1)
    for word, index in vocab_dict.items():
        vocab_list[index] = word

    # Rescale by sqrt(rows / cols); per the original author this is meant to
    # bring the row norms close to 1 (the average norm is printed below).
    g = g * jnp.sqrt(g.shape[0] / g.shape[1])

    print(jnp.average(jnp.linalg.norm(g, axis=1)))
    print(max(vocab_dict.values()))
    print(g.shape)

    # Get sparse representations and reconstructions
    z, g_sparse = get_sparse_representations_and_reconstructions(
        model.params,
        g,
        k=model.k
    )

    print(f"Sparse codes shape: {z.shape}")
    print(f"Reconstructions shape: {g_sparse.shape}")

    # Calculate sparsity metrics
    sparsity = (z == 0).mean()
    avg_active = (z != 0).sum(axis=1).mean()
    print(f"Sparsity: {sparsity:.4f} (fraction of zeros)")
    print(f"Average active units per sample: {avg_active:.2f} out of {z.shape[1]}")

    # Save the sparse codes
    np.save(codes_path, z)